#!/usr/bin/env python3
#
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""HRW4U Language Server Protocol implementation."""

from __future__ import annotations

import json
import os
import sys
import urllib.parse
from functools import lru_cache
from typing import Any

from antlr4.error.ErrorListener import ErrorListener

from hrw4u.hrw4uLexer import hrw4uLexer
from hrw4u.hrw4uParser import hrw4uParser
from hrw4u.visitor import HRW4UVisitor
from hrw4u.common import create_parse_tree
from hrw4u.tables import FUNCTION_MAP, STATEMENT_FUNCTION_MAP
from hrw4u.states import SectionType
from hrw4u.types import VarType, LanguageKeyword

from hrw4u_lsp.lsp.strings import (
    StringLiteralHandler, ContextAnalyzer, ExpressionParser, DocumentAnalyzer, CompletionContext, LSPDiagnostic,
    VariableDeclaration)
from hrw4u_lsp.lsp.hover import (
    HoverInfoProvider, FunctionHoverProvider, VariableHoverProvider, SectionHoverProvider, RegexHoverProvider,
    ModifierHoverProvider, OperatorHoverProvider)
from hrw4u_lsp.lsp.completions import CompletionProvider
from hrw4u_lsp.lsp.documentation import LSP_FUNCTION_DOCUMENTATION


class LSPErrorListener(ErrorListener):
    """ANTLR error listener that records syntax errors as LSP diagnostic dicts.

    Each reported error becomes an Error-severity (1) diagnostic with source
    ``hrw4u-parser``. Its range starts at the reported position and spans the
    offending token's text, or a single character when no token text exists.
    """

    def __init__(self) -> None:
        super().__init__()
        # Diagnostics collected so far, in the order ANTLR reported them.
        self.errors: list[dict[str, Any]] = []

    def syntaxError(self, _, offendingSymbol, line, column, msg, e) -> None:
        """Record one syntax error reported by ANTLR.

        Args:
            _: The recognizer that reported the error (unused).
            offendingSymbol: The offending token. It may be ``None`` (e.g. for
                lexer errors), in which case the range is one character wide.
            line: 1-based line number; converted to the 0-based LSP line.
            column: Column of the error, used unchanged as the LSP character.
            msg: Human-readable error message.
            e: The recognition exception, if any (unused).
        """
        # getattr guards against offendingSymbol being None, which previously raised AttributeError.
        token_text = getattr(offendingSymbol, "text", None)
        width = len(str(token_text)) if token_text else 1
        lsp_line = line - 1
        self.errors.append(
            {
                "range": {
                    "start": {
                        "line": lsp_line,
                        "character": column
                    },
                    "end": {
                        "line": lsp_line,
                        "character": column + width
                    }
                },
                "severity": 1,
                "message": msg,
                "source": "hrw4u-parser"
            })


class DocumentManager:
    """Manages document synchronization and parsing.

    All per-document state is keyed by the document URI sent by the client:

    * ``documents`` -- the latest full text of each open document
    * ``diagnostics`` -- diagnostics produced by the most recent analysis
    * ``variable_declarations`` -- variables found by ``DocumentAnalyzer``
    """

    def __init__(self) -> None:
        self.documents: dict[str, str] = {}
        self.diagnostics: dict[str, list[LSPDiagnostic]] = {}
        self.variable_declarations: dict[str, dict[str, VariableDeclaration]] = {}
        self._completion_provider = CompletionProvider()
        # URI -> filesystem path; an entry is dropped when its document is closed.
        self._uri_path_cache: dict[str, str] = {}

    # ------------------------------------------------------------------ completions helpers

    def _add_operator_completions(self, completions: list, base_prefix: str, current_section, context: CompletionContext) -> None:
        """Append operator and condition completions for the dotted ``base_prefix``."""
        completions.extend(
            self._completion_provider.get_operator_completions(base_prefix, current_section, context["replacement_range"]))

    def _add_function_completions(self, completions: list, function_map: dict, function_type: str) -> None:
        """Append the provider's function completions to ``completions``.

        ``function_map`` and ``function_type`` are kept for backward compatibility but
        do not influence the result (the provider is called without arguments). Calling
        this more than once per request therefore appends the same items again.
        """
        completions.extend(self._completion_provider.get_function_completions())

    def _add_keyword_completions(self, completions: list) -> None:
        """Append the provider's keyword completions to ``completions``."""
        completions.extend(self._completion_provider.get_keyword_completions())

    # ------------------------------------------------------------------ document lifecycle

    def open_document(self, uri: str, text: str) -> None:
        """Store the text of a newly opened document and analyze it."""
        self.documents[uri] = text
        self._analyze_document(uri)

    def change_document(self, uri: str, text: str) -> None:
        """Replace a document's text with ``text`` and re-analyze it."""
        self.documents[uri] = text
        self._analyze_document(uri)

    def close_document(self, uri: str) -> None:
        """Forget all state kept for ``uri``; closing an unknown URI is a no-op."""
        self.documents.pop(uri, None)
        self.diagnostics.pop(uri, None)
        self.variable_declarations.pop(uri, None)
        self._uri_path_cache.pop(uri, None)

    def _uri_to_path(self, uri: str) -> str:
        """Convert a ``file:`` URI to a local path, caching the result.

        Percent-escapes in the path are decoded, and a ``localhost`` authority is
        dropped. Any other authority is kept as a ``//host`` prefix. Strings that
        are not ``file:`` URIs are returned unchanged.
        """
        cached = self._uri_path_cache.get(uri)
        if cached is not None:
            return cached

        parsed = urllib.parse.urlparse(uri)
        if parsed.scheme == "file":
            path = urllib.parse.unquote(parsed.path)
            if parsed.netloc and parsed.netloc != "localhost":
                # Keep the authority for UNC-style URIs such as file://server/share/x.
                path = f"//{parsed.netloc}{path}"
        else:
            path = uri

        self._uri_path_cache[uri] = path
        return path

    # ------------------------------------------------------------------ diagnostics helpers

    @staticmethod
    def _make_diagnostic(line: int, start: int, end: int, message: str, source: str) -> dict[str, Any]:
        """Build a single-line, Error-severity (1) diagnostic dict."""
        return {
            "range": {
                "start": {
                    "line": line,
                    "character": start
                },
                "end": {
                    "line": line,
                    "character": end
                }
            },
            "severity": 1,
            "message": message,
            "source": source
        }

    @staticmethod
    def _error_message(error_obj: object) -> str:
        """Return ``str(error_obj)`` followed by any exception notes, one per line."""
        message = str(error_obj)
        notes = getattr(error_obj, '__notes__', None)
        if notes:
            message += "\n" + "\n".join(str(note) for note in notes)
        return message

    @staticmethod
    def _int_attr(obj: object, name: str, default: int) -> int:
        """Return ``obj.<name>`` when it is an int, otherwise ``default`` (missing or None)."""
        value = getattr(obj, name, None)
        return value if isinstance(value, int) else default

    @staticmethod
    def _find_suggestion(notes: Any) -> str | None:
        """Return the text after ``Did you mean:`` in the first note containing it.

        Returns ``None`` when no note carries a suggestion, and ``""`` when the
        marker is present but nothing follows it.
        """
        for note in notes or ():
            _, marker, tail = str(note).partition("Did you mean:")
            if marker:
                return tail.split("\n", 1)[0].strip().rstrip("?")
        return None

    @staticmethod
    def _token_end(line_text: str, start: int) -> int:
        """Return the end column of the ``[A-Za-z0-9_.]`` run at ``start`` (at least ``start + 1``)."""
        end = start
        while end < len(line_text) and (line_text[end].isalnum() or line_text[end] in '._'):
            end += 1
        return end if end > start else start + 1

    @staticmethod
    def _suggestion_end(line_text: str, start: int, suggestion: str) -> int:
        """Return the end column of the dotted identifier at ``start``.

        At most as many ``.``-separated segments are covered as ``suggestion``
        has, so the range matches the shape of the suggested replacement. The
        range is always at least one character wide.
        """
        pos = start
        remaining_dots = suggestion.count('.')
        while True:
            while pos < len(line_text) and (line_text[pos].isalnum() or line_text[pos] == '_'):
                pos += 1
            if remaining_dots == 0 or pos >= len(line_text) or line_text[pos] != '.':
                break
            pos += 1
            remaining_dots -= 1
        return max(pos, start + 1)

    @classmethod
    def _semantic_error_range(cls, error_obj: object, lines: list[str]) -> tuple[int, int, int]:
        """Return ``(line, start_col, end_col)`` (0-based) for a collected semantic error.

        ``error_obj.line`` is treated as 1-based and ``error_obj.column`` as 0-based.
        Missing or non-integer values fall back to line 1 / column 0. If a note
        carries a "Did you mean:" suggestion, the range covers the matching dotted
        identifier. Otherwise it covers the whole identifier token at the error.
        """
        line_num = max(0, cls._int_attr(error_obj, 'line', 1) - 1)
        col_num = max(0, cls._int_attr(error_obj, 'column', 0))

        if line_num >= len(lines):
            return line_num, col_num, col_num + 1

        line_text = lines[line_num]
        suggestion = cls._find_suggestion(getattr(error_obj, '__notes__', None))
        if suggestion is None:
            end_col = cls._token_end(line_text, col_num)
        elif suggestion:
            end_col = cls._suggestion_end(line_text, col_num, suggestion)
        else:
            end_col = col_num + 1
        return line_num, col_num, end_col

    def _analyze_document(self, uri: str) -> None:
        """Re-analyze ``uri`` and store its variable declarations and diagnostics.

        Diagnostics are collected in this order:

        1. section-name validation;
        2. a visitor crash, if any;
        3. semantic errors gathered by the error collector;
        4. parser syntax errors that do not start at the same position as a
           semantic error (the semantic error usually carries a suggestion).

        Any other exception raised while parsing is reported as a single
        "Parse error" diagnostic rather than propagated.
        """
        text = self.documents.get(uri, "")
        diagnostics: list[Any] = []

        self.variable_declarations[uri] = DocumentAnalyzer.parse_variable_declarations(text)
        diagnostics.extend(DocumentAnalyzer.validate_section_names(text))

        try:
            filename = self._uri_to_path(uri)
            tree, parser, error_collector = create_parse_tree(text, filename, hrw4uLexer, hrw4uParser, "HRW4U", collect_errors=True)

            # Held back so they can be de-duplicated against semantic errors below.
            parser_errors = list(getattr(parser, '_syntax_errors', None) or ())

            if tree is not None:
                visitor = HRW4UVisitor(filename=filename, error_collector=error_collector)
                try:
                    visitor.visit(tree)
                except Exception as e:
                    diagnostics.append(self._make_diagnostic(0, 0, 10, f"Visitor error: {str(e)}", "hrw4u-visitor"))

            semantic_positions: set[tuple[int, int]] = set()
            if error_collector and error_collector.has_errors():
                lines = text.split('\n')
                for error_obj in error_collector.errors:
                    line_num, col_num, end_col = self._semantic_error_range(error_obj, lines)
                    semantic_positions.add((line_num, col_num))
                    diagnostics.append(
                        self._make_diagnostic(line_num, col_num, end_col, self._error_message(error_obj), "hrw4u"))

            for parser_error in parser_errors:
                start = parser_error.get("range", {}).get("start", {})
                if (start.get("line", -1), start.get("character", -1)) not in semantic_positions:
                    diagnostics.append(parser_error)

        except Exception as e:
            diagnostics.append(self._make_diagnostic(0, 0, 10, f"Parse error: {str(e)}", "hrw4u-parser"))

        # Plain dict diagnostics are stored as-is; they are structurally LSPDiagnostic-shaped.
        self.diagnostics[uri] = diagnostics  # type: ignore[assignment]

    def get_diagnostics(self, uri: str) -> list[LSPDiagnostic]:
        """Return the diagnostics from the last analysis of ``uri`` (empty if unknown)."""
        return self.diagnostics.get(uri, [])

    # ------------------------------------------------------------------ completion / hover

    def get_completions(self, uri: str, line: int, character: int) -> list[dict[str, Any]]:
        """Return completion items for the 0-based position ``(line, character)`` in ``uri``."""
        lines = self.documents.get(uri, "").split('\n')
        if not 0 <= line < len(lines):
            return []

        context: CompletionContext = ContextAnalyzer.determine_completion_context(lines, line, character)
        completions: list[dict[str, Any]] = []

        if context["is_section_context"]:
            completions.extend(self._completion_provider.get_section_completions())
        elif context["has_dot"]:
            # Dot-notation completions (operators and conditions).
            self._add_operator_completions(completions, context["dot_prefix"], context["current_section"], context)
        elif context["is_function_context"]:
            # _add_function_completions ignores its map argument, so a single call
            # yields the full list; calling it once per map duplicated every item.
            self._add_function_completions(completions, FUNCTION_MAP, "Function")

        # Variable type completions (in VARS context).
        current_section = context["current_section"]
        if current_section and current_section.value == "VARS":
            completions.extend(self._completion_provider.get_variable_type_completions())

        if context["allows_keywords"]:
            self._add_keyword_completions(completions)

        return completions

    def get_hover_info(self, uri: str, line: int, character: int) -> dict[str, Any] | None:
        """Get hover information for the symbol at the 0-based position ``(line, character)``.

        Returns ``None`` for out-of-range positions, comments, control-flow
        keywords, and symbols none of the hover providers recognise.
        """
        lines = self.documents.get(uri, "").split('\n')
        if not 0 <= line < len(lines):
            return None

        current_line = lines[line]
        if not 0 <= character < len(current_line):
            return None

        comment_start = current_line.find('#')
        if comment_start != -1 and character >= comment_start:
            return None

        string_info = StringLiteralHandler.check_string_literal(current_line, character)
        if string_info:
            return string_info

        # Condition modifiers (NOCASE, MID, etc.) used with the 'with' keyword.
        modifier_info = ModifierHoverProvider.get_modifier_hover_info(current_line, character)
        if modifier_info:
            return modifier_info

        full_expression_info = ExpressionParser.parse_dotted_expression(current_line, character)
        if full_expression_info:
            return full_expression_info

        # Expand to word boundaries (hyphens included for header names).
        word_start = character
        word_end = character
        while word_start > 0 and (current_line[word_start - 1].isalnum() or current_line[word_start - 1] in '._@-'):
            word_start -= 1
        while word_end < len(current_line) and (current_line[word_end].isalnum() or current_line[word_end] in '._@-'):
            word_end += 1

        if word_start == word_end:
            return None

        word = current_line[word_start:word_end]

        # No hover for basic control-flow keywords.
        if word.lower() in LanguageKeyword.get_keywords_with_descriptions():
            return None

        regex_info = RegexHoverProvider.get_regex_hover_info(current_line, character)
        if regex_info:
            return regex_info

        # Section declaration: "SECTION_NAME {" or "SECTION_NAME{".
        stripped_line = current_line.strip()
        if stripped_line.startswith(word) and stripped_line[len(word):].lstrip().startswith('{'):
            return SectionHoverProvider.get_section_hover_info(word)

        var_type_names = {var_type.value[0] for var_type in VarType}
        if word.lower() in var_type_names:
            return VariableHoverProvider.get_variable_type_hover_info(word)

        if current_line[word_end:].strip().startswith('('):
            return FunctionHoverProvider.get_function_hover_info(word)

        variable_info = VariableHoverProvider.get_variable_hover_info(self.variable_declarations, uri, word)
        if variable_info:
            return variable_info

        # Operators/conditions are the last thing tried before giving up.
        operator_info = OperatorHoverProvider.get_operator_hover_info(word)
        if operator_info:
            return operator_info

        return None


class HRW4ULanguageServer:
    """Main LSP server implementation.

    Speaks JSON-RPC 2.0 over stdin/stdout using LSP base-protocol framing:
    ``Content-Length`` header(s), an empty line, then a body of exactly that
    many bytes. Document handling is delegated to a :class:`DocumentManager`.
    """

    def __init__(self) -> None:
        self.document_manager = DocumentManager()
        self.running = False
        self.request_id = 0

    def start(self) -> None:
        """Read and dispatch messages until stdin reaches EOF or ``exit`` is received.

        stdin is read through its binary ``buffer`` when one exists. Bodies are
        therefore consumed by *byte* count, as ``Content-Length`` specifies,
        rather than by character count, which breaks on multi-byte UTF-8 text.
        """
        stream = getattr(sys.stdin, "buffer", sys.stdin)
        self.running = True
        while self.running:
            try:
                content_length = 0
                while True:
                    raw_line = stream.readline()
                    if not raw_line:
                        self.running = False
                        return

                    line = raw_line.decode("utf-8", errors="replace") if isinstance(raw_line, bytes) else raw_line
                    line = line.strip()
                    if line.startswith("Content-Length:"):
                        content_length = int(line.split(":", 1)[1].strip())
                    elif line == "":
                        break

                if content_length > 0:
                    content = stream.read(content_length)
                    if not content:
                        self.running = False
                        break
                    self._process_content(content)

            except (EOFError, KeyboardInterrupt):
                self.running = False
                break
            except Exception as e:
                self._send_error_response(None, -32603, f"Internal error: {str(e)}")

    def _process_content(self, content: bytes | str) -> None:
        """Decode one message body and dispatch it, reporting failures as JSON-RPC errors.

        Undecodable bodies get -32700 and non-object messages get -32600 (both
        with ``id: null``). A handler exception is answered with -32603 carrying
        the request's id. For notifications, which must not receive a response,
        the exception is written to stderr instead.
        """
        try:
            message = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            self._send_error_response(None, -32700, f"Parse error: {str(e)}")
            return

        if not isinstance(message, dict):
            self._send_error_response(None, -32600, "Invalid Request: expected a JSON object")
            return

        try:
            self._handle_message(message)
        except Exception as e:
            if "id" in message and "method" in message:
                self._send_error_response(message["id"], -32603, f"Internal error: {str(e)}")
            else:
                print(f"hrw4u-lsp: error handling {message.get('method')!r}: {e}", file=sys.stderr)

    def _handle_message(self, message: dict[str, Any]) -> None:
        """Dispatch one decoded JSON-RPC message to its handler.

        Unknown *requests* (messages with both a method and an ``id``) are
        answered with -32601 so the client does not wait forever. Unknown
        notifications, ``$/``-prefixed protocol messages, and responses from
        the client (no method) are ignored.
        """
        method = message.get("method", "")
        if method == "initialized":
            return
        if method == "exit":
            self.running = False
            return

        handler = {
            "initialize": self._handle_initialize,
            "textDocument/didOpen": self._handle_did_open,
            "textDocument/didChange": self._handle_did_change,
            "textDocument/didClose": self._handle_did_close,
            "textDocument/completion": self._handle_completion,
            "textDocument/hover": self._handle_hover,
            "textDocument/codeAction": self._handle_code_action,
            "shutdown": self._handle_shutdown,
        }.get(method)

        if handler is not None:
            handler(message)
        elif isinstance(method, str) and method and "id" in message and not method.startswith("$/"):
            self._send_error_response(message["id"], -32601, f"Method not found: {method}")

    def _handle_initialize(self, message: dict[str, Any]) -> None:
        """Reply to ``initialize`` with the server's capabilities (full text sync = 1)."""
        response = {
            "jsonrpc": "2.0",
            "id": message["id"],
            "result":
                {
                    "capabilities":
                        {
                            "textDocumentSync": 1,
                            "completionProvider": {
                                "triggerCharacters": ["."]
                            },
                            "hoverProvider": True,
                            "codeActionProvider": True
                        }
                }
        }
        self._send_message(response)

    def _handle_did_open(self, message: dict[str, Any]) -> None:
        """Store and analyze a newly opened document, then publish its diagnostics."""
        text_document = message["params"]["textDocument"]
        uri = text_document["uri"]
        self.document_manager.open_document(uri, text_document["text"])
        self._send_diagnostics(uri)

    def _handle_did_change(self, message: dict[str, Any]) -> None:
        """Apply a full-text change and publish fresh diagnostics.

        With full sync every change carries the whole document, so the last
        entry is the current text.
        """
        params = message["params"]
        uri = params["textDocument"]["uri"]
        changes = params["contentChanges"]

        if changes:
            self.document_manager.change_document(uri, changes[-1]["text"])
            self._send_diagnostics(uri)

    def _handle_did_close(self, message: dict[str, Any]) -> None:
        """Drop a closed document and publish an empty diagnostic list to clear the client's view."""
        uri = message["params"]["textDocument"]["uri"]
        self.document_manager.close_document(uri)
        # get_diagnostics returns [] for a closed URI, which clears stale markers.
        self._send_diagnostics(uri)

    def _handle_completion(self, message: dict[str, Any]) -> None:
        """Reply to a completion request with a complete (non-incremental) item list."""
        params = message["params"]
        position = params["position"]
        completions = self.document_manager.get_completions(
            params["textDocument"]["uri"], position["line"], position["character"])

        response = {"jsonrpc": "2.0", "id": message["id"], "result": {"isIncomplete": False, "items": completions}}
        self._send_message(response)

    def _handle_hover(self, message: dict[str, Any]) -> None:
        """Reply to a hover request; the result is ``null`` when nothing is known."""
        params = message["params"]
        position = params["position"]
        hover_info = self.document_manager.get_hover_info(
            params["textDocument"]["uri"], position["line"], position["character"])

        response = {"jsonrpc": "2.0", "id": message["id"], "result": hover_info}
        self._send_message(response)

    def _handle_code_action(self, message: dict[str, Any]) -> None:
        """Offer quick fixes for ``hrw4u`` diagnostics whose message contains "Did you mean:".

        Each fix replaces the diagnostic's range with the suggested text, taken
        from the first line after the marker with any trailing ``?`` removed.
        """
        params = message["params"]
        uri = params["textDocument"]["uri"]
        context = params.get("context", {})

        actions = []
        for diagnostic in context.get("diagnostics", []):
            if diagnostic.get("source") != "hrw4u":
                continue
            _, marker, tail = diagnostic.get("message", "").partition("Did you mean:")
            if not marker:
                continue
            suggestion = tail.split("\n", 1)[0].strip().rstrip("?")
            if not suggestion:
                continue
            actions.append(
                {
                    "title": f"Replace with '{suggestion}'",
                    "kind": "quickfix",
                    "diagnostics": [diagnostic],
                    "edit": {
                        "changes": {
                            uri: [{
                                "range": diagnostic["range"],
                                "newText": suggestion
                            }]
                        }
                    }
                })

        response = {"jsonrpc": "2.0", "id": message["id"], "result": actions}
        self._send_message(response)

    def _handle_shutdown(self, message: dict[str, Any]) -> None:
        """Acknowledge ``shutdown`` with a null result; the loop stops on ``exit``."""
        response = {"jsonrpc": "2.0", "id": message["id"], "result": None}
        self._send_message(response)

    def _send_diagnostics(self, uri: str) -> None:
        """Publish the document manager's current diagnostics for ``uri``."""
        notification = {
            "jsonrpc": "2.0",
            "method": "textDocument/publishDiagnostics",
            "params": {
                "uri": uri,
                "diagnostics": self.document_manager.get_diagnostics(uri)
            }
        }
        self._send_message(notification)

    def _send_message(self, message: dict[str, Any]) -> None:
        """Write one framed JSON-RPC message to stdout.

        The body is UTF-8 encoded and ``Content-Length`` is its byte length. When
        stdout exposes a binary ``buffer``, the frame is written there so that
        text-mode newline translation cannot alter the ``\\r\\n`` separators.
        """
        body = json.dumps(message).encode("utf-8")
        frame = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii") + body

        binary_out = getattr(sys.stdout, "buffer", None)
        if binary_out is None:
            sys.stdout.write(frame.decode("utf-8"))
            sys.stdout.flush()
            return

        # Flush any text written through the text layer first so ordering is preserved.
        sys.stdout.flush()
        binary_out.write(frame)
        binary_out.flush()

    def _send_error_response(self, request_id: int | str | None, code: int, message: str) -> None:
        """Send a JSON-RPC error response with the given ``code`` and ``message``."""
        response = {"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}
        self._send_message(response)


def main() -> None:
    """Create an HRW4U language server and serve it over stdio until it exits."""
    HRW4ULanguageServer().start()


if __name__ == "__main__":
    main()
